"""Compute the gauge link variables U_μ from the gauge potential A_μ.

This module exposes a single function, `compute_U_from_A`, which takes the
gauge potential array and the gauge group name and returns the link
variables.  For U(1) groups the link variables are complex phases; for
SU(N) groups they are unitary matrices obtained via the matrix exponential.

The implementation relies only on NumPy and, when it can be imported, SciPy's
matrix exponential (``scipy.linalg.expm``).  If importing SciPy fails for any
reason, a fallback based on eigendecomposition is used instead.
"""

import numpy as np

try:
    # SciPy's expm is used whenever the import succeeds.
    from scipy.linalg import expm  # type: ignore
except Exception:
    # Any failure while importing SciPy (not only ImportError) ends up here and
    # installs the pure-NumPy replacement below under the same name ``expm``.
    def expm(mat: np.ndarray) -> np.ndarray:
        """Return ``exp(mat)`` computed from an eigendecomposition of ``mat``.

        The square matrix is factored as ``mat = V diag(w) V^{-1}`` with
        :func:`numpy.linalg.eig`, and ``V diag(exp(w)) V^{-1}`` is returned.
        This is only valid when ``mat`` is diagonalizable, and accuracy may
        degrade when the eigenvector matrix ``V`` is ill-conditioned.  Only a
        single 2-D matrix is handled (``np.diag`` is applied to the
        eigenvalues).  A new array is allocated on every call.
        """
        eigvals, eigvecs = np.linalg.eig(mat)
        exp_diag = np.diag(np.exp(eigvals))
        return eigvecs @ exp_diag @ np.linalg.inv(eigvecs)


def compute_U_from_A(A: np.ndarray, gauge_group: str = 'U1') -> np.ndarray:
    """Compute link variables from the gauge potential array.

    Parameters
    ----------
    A : np.ndarray
        Gauge potential array.  Any array-like is accepted and converted with
        ``np.asarray``.  For U(1) this is an array of real phases (typically
        one-dimensional).  For SU(N) it has shape ``(num_links, N, N)``, or
        more generally ``(..., N, N)`` with at least one leading link axis.
    gauge_group : str, optional
        Name of the gauge group, case-insensitive.  ``'U1'`` selects the
        scalar phase computation.  Every other value (e.g. ``'SU2'``,
        ``'SU3'``) selects the matrix-exponential path.  Defaults to ``'U1'``.

    Returns
    -------
    np.ndarray
        The link variables U_μ.  For U(1) this is ``exp(1j * A)``, with the
        same shape as ``A``.  Otherwise it is a complex array with the same
        shape as ``A``, in which each trailing ``N x N`` matrix is
        ``expm(1j * A_matrix)``.  The result is unitary only when each matrix
        in ``A`` is Hermitian; this is not checked here.

    Raises
    ------
    ValueError
        If the matrix path is selected and ``A`` is not at least
        three-dimensional with square trailing ``N x N`` matrices.
    """
    # Accept lists/tuples too: ``1j * list`` would raise TypeError otherwise.
    A = np.asarray(A)
    gauge = gauge_group.upper()
    if gauge == 'U1':
        # Scalar case: each link variable is just the phase exp(i*A).
        return np.exp(1j * A)

    # Matrix case.  Validate up front so callers get a clear message instead
    # of an opaque failure from inside the matrix exponential.
    if A.ndim < 3 or A.shape[-1] != A.shape[-2]:
        raise ValueError(
            f"gauge group {gauge_group!r} expects A of shape (num_links, N, N) "
            f"or (..., N, N); got shape {A.shape}"
        )

    U = np.empty_like(A, dtype=complex)
    # Exponentiate each N x N link matrix on its own.  Looping over every
    # leading index (not just the first axis) gives the same result whether
    # or not the available ``expm`` supports stacked matrices.
    for idx in np.ndindex(A.shape[:-2]):
        U[idx] = expm(1j * A[idx])
    return U